{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FhGuhbZ6M5tl"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "AwOEIRJC6Une"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EIdT9iu_Z4Rb"
      },
      "source": [
        "# Multilayer perceptrons for digit recognition with Core APIs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bBIlTPscrIT9"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/core/mlp_core\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/core/mlp_core.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/core/mlp_core.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/core/mlp_core.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SjAxxRpBzVYg"
      },
      "source": [
        "This notebook uses the [TensorFlow Core low-level APIs](https://www.tensorflow.org/guide/core) to build an end-to-end machine learning workflow for handwritten digit classification with [multilayer perceptrons](https://developers.google.com/machine-learning/crash-course/introduction-to-neural-networks/anatomy) and the [MNIST dataset](http://yann.lecun.com/exdb/mnist). Visit the [Core APIs overview](https://www.tensorflow.org/guide/core) to learn more about TensorFlow Core and its intended use cases."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GHVMVIFHSzl1"
      },
      "source": [
        "## Multilayer perceptron (MLP) overview\n",
        "\n",
        "The Multilayer Perceptron (MLP) is a type of feedforward neural network used to approach [multiclass classification](https://developers.google.com/machine-learning/crash-course/multi-class-neural-networks/video-lecture) problems. Before building an MLP, it is crucial to understand the concepts of perceptrons, layers, and activation functions.\n",
        "\n",
        "Multilayer Perceptrons are made up of functional units called perceptrons. The equation of a perceptron is as follows:\n",
        "\n",
        "$$Z = \\vec{w}⋅\\mathrm{X} + b$$\n",
        "\n",
        "where\n",
        "\n",
        "* $Z$: perceptron output\n",
        "* $\\mathrm{X}$: feature matrix\n",
        "* $\\vec{w}$: weight vector\n",
        "* $b$: bias\n",
        "\n",
        "When these perceptrons are stacked, they form structures called dense layers, which can then be connected to build a neural network. A dense layer's equation is similar to a perceptron's but uses a weight matrix and a bias vector instead:\n",
        "\n",
        "$$Z = \\mathrm{W}⋅\\mathrm{X} + \\vec{b}$$\n",
        "\n",
        "where\n",
        "\n",
        "* $Z$: dense layer output\n",
        "* $\\mathrm{X}$: feature matrix\n",
        "* $\\mathrm{W}$: weight matrix\n",
        "* $\\vec{b}$: bias vector\n",
        "\n",
        "\n",
        "In an MLP, multiple dense layers are connected in such a way that the outputs of one layer are fully connected to the inputs of the next layer. Adding non-linear activation functions to the outputs of dense layers can help the MLP classifier learn complex decision boundaries and generalize well to unseen data."
      ]
    },
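    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the dense layer equation, the following sketch computes a forward pass with NumPy. The shapes, values, and variable names are made up for this example. The code stores one example per row, so it computes $\\mathrm{X}\\mathrm{W} + \\vec{b}$ rather than $\\mathrm{W}\\mathrm{X} + \\vec{b}$; the two conventions differ only by a transpose."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical toy data: 2 examples with 3 features each\n",
        "X = np.array([[1., 2., 3.],\n",
        "              [4., 5., 6.]])\n",
        "# Weight matrix mapping 3 input features to 2 output units\n",
        "W = np.array([[0.1, -0.2],\n",
        "              [0.0, 0.3],\n",
        "              [0.5, 0.1]])\n",
        "# Bias vector with one entry per output unit\n",
        "b = np.array([0.5, -0.5])\n",
        "\n",
        "# Dense layer output: one row per example, one column per unit\n",
        "Z = X @ W + b\n",
        "print(Z.shape)\n",
        "print(Z)"
      ]
    },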
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nchsZfwEVtVs"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Import TensorFlow, [pandas](https://pandas.pydata.org), [Matplotlib](https://matplotlib.org) and [seaborn](https://seaborn.pydata.org) to get started."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mSfgqmwBagw_"
      },
      "outputs": [],
      "source": [
        "# Use seaborn for countplot.\n",
        "!pip install -q seaborn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1rRo8oNqZ-Rj"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "import seaborn as sns\n",
        "import tempfile\n",
        "import os\n",
        "# Preset Matplotlib figure sizes.\n",
        "matplotlib.rcParams['figure.figsize'] = [9, 6]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9xQKvCJ85kCQ"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "print(tf.__version__)\n",
        "# Set random seed for reproducible results \n",
        "tf.random.set_seed(22)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F_72b0LCNbjx"
      },
      "source": [
        "## Load the data\n",
        "\n",
        "This tutorial uses the [MNIST dataset](http://yann.lecun.com/exdb/mnist), and demonstrates how to build an MLP model that can classify handwritten digits. The dataset is available from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/mnist).\n",
        "\n",
        "Split the MNIST dataset into training, validation, and testing sets. The validation set can be used to gauge the model's generalizability during training so that the test set can serve as a final unbiased estimator for the model's performance.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uiuh0B098_3p"
      },
      "outputs": [],
      "source": [
        "train_data, val_data, test_data = tfds.load(\"mnist\", \n",
        "                                            split=['train[10000:]', 'train[0:10000]', 'test'],\n",
        "                                            batch_size=128, as_supervised=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X9uN3Lf6ANtn"
      },
      "source": [
        "The MNIST dataset consists of handwritten digits and their corresponding true labels. Visualize a couple of examples below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6V8hSqJ7AMjQ"
      },
      "outputs": [],
      "source": [
        "x_viz, y_viz = tfds.load(\"mnist\", split=['train[:1500]'], batch_size=-1, as_supervised=True)[0]\n",
        "x_viz = tf.squeeze(x_viz, axis=3)\n",
        "\n",
        "for i in range(9):\n",
        "    plt.subplot(3,3,1+i)\n",
        "    plt.axis('off')\n",
        "    plt.imshow(x_viz[i], cmap='gray')\n",
        "    plt.title(f\"True Label: {y_viz[i]}\")\n",
        "    plt.subplots_adjust(hspace=.5)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bRald9dSE4qS"
      },
      "source": [
        "Also review the distribution of digits in the training data to verify that each class is well represented in the dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rj3K4XgQE7qR"
      },
      "outputs": [],
      "source": [
        "sns.countplot(x=y_viz.numpy());\n",
        "plt.xlabel('Digits')\n",
        "plt.title(\"MNIST Digit Distribution\");"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x_Wt4bDx_BRV"
      },
      "source": [
        "## Preprocess the data\n",
        "\n",
        "First, reshape the feature matrices to be 2-dimensional by flattening the images. Next, rescale the data so that the pixel values of [0,255] fit into the range of [0,1]. This step ensures that the input pixels have similar distributions and helps with training convergence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JSyCm2V2_AvI"
      },
      "outputs": [],
      "source": [
        "def preprocess(x, y):\n",
        "  # Reshaping the data\n",
        "  x = tf.reshape(x, shape=[-1, 784])\n",
        "  # Rescaling the data\n",
        "  x = x/255\n",
        "  return x, y\n",
        "\n",
        "train_data, val_data = train_data.map(preprocess), val_data.map(preprocess)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6o3CrycBXA2s"
      },
      "source": [
        "## Build the MLP\n",
        "\n",
        "Start by visualizing the [ReLU](https://developers.google.com/machine-learning/glossary#ReLU) and [Softmax](https://developers.google.com/machine-learning/glossary#softmax) activation functions, which are available as `tf.nn.relu` and `tf.nn.softmax`, respectively. The ReLU is a non-linear activation function that outputs the input if it is positive and 0 otherwise:\n",
        "\n",
        "$$\\text{ReLU}(X) = \\max(0, X)$$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hYunzt3UyT9G"
      },
      "outputs": [],
      "source": [
        "x = tf.linspace(-2, 2, 201)\n",
        "x = tf.cast(x, tf.float32)\n",
        "plt.plot(x, tf.nn.relu(x));\n",
        "plt.xlabel('x')\n",
        "plt.ylabel('ReLU(x)')\n",
        "plt.title('ReLU activation function');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fuGrM9jMwsRM"
      },
      "source": [
        "The softmax activation function is a normalized exponential function that converts $m$ real numbers into a probability distribution with $m$ outcomes/classes. This is useful for predicting class probabilities from a neural network's output:\n",
        "\n",
        "$$\\text{Softmax}(X) = \\frac{e^{X}}{\\sum_{i=1}^{m}e^{X_i}}$$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fVM8pvhWwuwI"
      },
      "outputs": [],
      "source": [
        "x = tf.linspace(-4, 4, 201)\n",
        "x = tf.cast(x, tf.float32)\n",
        "plt.plot(x, tf.nn.softmax(x, axis=0));\n",
        "plt.xlabel('x')\n",
        "plt.ylabel('Softmax(x)')\n",
        "plt.title('Softmax activation function');"
      ]
    },
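    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the formula concrete, the sketch below evaluates softmax on three made-up scores with NumPy. The helper name `softmax_np` is hypothetical. Subtracting the maximum score before exponentiating is a common numerical-stability trick assumed here for illustration; it does not change the result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax_np(scores):\n",
        "  # Shift by the max score for numerical stability (the result is unchanged)\n",
        "  exp_scores = np.exp(scores - np.max(scores))\n",
        "  return exp_scores / np.sum(exp_scores)\n",
        "\n",
        "# Hypothetical scores for 3 classes\n",
        "scores = np.array([1.0, 2.0, 3.0])\n",
        "probs = softmax_np(scores)\n",
        "# Larger scores receive larger probabilities\n",
        "print(probs)\n",
        "# The probabilities sum to 1\n",
        "print(probs.sum())"
      ]
    },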
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OHW6Yvg2yS6H"
      },
      "source": [
        "### The dense layer\n",
        "\n",
        "Create a class for the dense layer. By definition, the outputs of one layer are fully connected to the inputs of the next layer in an MLP. Therefore, the input dimension for a dense layer can be inferred based on the output dimension of its previous layer and does not need to be specified upfront during its initialization. The weights should also be initialized properly to prevent activation outputs from becoming too large or small. One of the most popular weight initialization methods is the Xavier scheme, where each element of the weight matrix is sampled in the following manner:\n",
        "\n",
        "$$W_{ij} \\sim \\text{Uniform}(-\\frac{\\sqrt{6}}{\\sqrt{n + m}},\\frac{\\sqrt{6}}{\\sqrt{n + m}})$$\n",
        "\n",
        "where $n$ and $m$ are the layer's input and output dimensions. The bias vector can be initialized to zeros."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "re1SSFyBdMrS"
      },
      "outputs": [],
      "source": [
        "def xavier_init(shape):\n",
        "  # Computes the xavier initialization values for a weight matrix\n",
        "  in_dim, out_dim = shape\n",
        "  xavier_lim = tf.sqrt(6.)/tf.sqrt(tf.cast(in_dim + out_dim, tf.float32))\n",
        "  weight_vals = tf.random.uniform(shape=(in_dim, out_dim), \n",
        "                                  minval=-xavier_lim, maxval=xavier_lim, seed=22)\n",
        "  return weight_vals"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "otDFX4u6e6ml"
      },
      "source": [
        "The Xavier initialization method can also be implemented with `tf.keras.initializers.GlorotUniform`."
      ]
    },
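    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of that alternative, assuming the initializer is called with a `shape` argument. The layer dimensions and variable names are chosen for illustration only. The check confirms that the sampled values stay within the Xavier limits given above, with a small tolerance for floating-point rounding."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical layer dimensions chosen for illustration\n",
        "glorot_in_dim, glorot_out_dim = 784, 700\n",
        "glorot_init = tf.keras.initializers.GlorotUniform(seed=22)\n",
        "glorot_weights = glorot_init(shape=(glorot_in_dim, glorot_out_dim))\n",
        "\n",
        "# Sampled values should lie within the Xavier limits\n",
        "glorot_lim = (6. / (glorot_in_dim + glorot_out_dim)) ** 0.5\n",
        "max_abs_weight = float(tf.reduce_max(tf.abs(glorot_weights)))\n",
        "print(glorot_weights.shape)\n",
        "print(max_abs_weight <= glorot_lim + 1e-6)"
      ]
    },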
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IM0yJos25FG5"
      },
      "outputs": [],
      "source": [
        "class DenseLayer(tf.Module):\n",
        "\n",
        "  def __init__(self, out_dim, weight_init=xavier_init, activation=tf.identity):\n",
        "    # Initialize the dimensions and activation functions\n",
        "    self.out_dim = out_dim\n",
        "    self.weight_init = weight_init\n",
        "    self.activation = activation\n",
        "    self.built = False\n",
        "\n",
        "  def __call__(self, x):\n",
        "    if not self.built:\n",
        "      # Infer the input dimension based on first call\n",
        "      self.in_dim = x.shape[1]\n",
        "      # Initialize the weights and biases\n",
        "      self.w = tf.Variable(self.weight_init(shape=(self.in_dim, self.out_dim)))\n",
        "      self.b = tf.Variable(tf.zeros(shape=(self.out_dim,)))\n",
        "      self.built = True\n",
        "    # Compute the forward pass\n",
        "    z = tf.add(tf.matmul(x, self.w), self.b)\n",
        "    return self.activation(z)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X-7MzpjgyHg6"
      },
      "source": [
        "Next, build a class for the MLP model that executes layers sequentially.\n",
        "Remember that the model variables are only available after the first sequence of dense layer calls due to dimension inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6XisRWiCyHAb"
      },
      "outputs": [],
      "source": [
        "class MLP(tf.Module):\n",
        "\n",
        "  def __init__(self, layers):\n",
        "    self.layers = layers\n",
        "\n",
        "  @tf.function\n",
        "  def __call__(self, x):\n",
        "    # Execute the model's layers sequentially\n",
        "    for layer in self.layers:\n",
        "      x = layer(x)\n",
        "    return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "luXKup-43nd7"
      },
      "source": [
        "Initialize an MLP model with the following architecture:\n",
        "\n",
        "- Forward Pass: ReLU(784 x 700) x ReLU(700 x 500) x Softmax(500 x 10)\n",
        "\n",
        "The softmax activation function does not need to be applied by the MLP. It is computed separately in the loss and prediction functions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VmlACuki3oPi"
      },
      "outputs": [],
      "source": [
        "hidden_layer_1_size = 700\n",
        "hidden_layer_2_size = 500\n",
        "output_size = 10\n",
        "\n",
        "mlp_model = MLP([\n",
        "    DenseLayer(out_dim=hidden_layer_1_size, activation=tf.nn.relu),\n",
        "    DenseLayer(out_dim=hidden_layer_2_size, activation=tf.nn.relu),\n",
        "    DenseLayer(out_dim=output_size)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tyBATDoRmDkg"
      },
      "source": [
        "### Define the loss function\n",
        "\n",
        "The cross-entropy loss function is a great choice for multiclass classification problems since it measures the negative-log-likelihood of the data according to the model's probability predictions. The higher the probability assigned to the true class, the lower the loss. The equation for the cross-entropy loss is as follows:\n",
        "\n",
        "$$L = -\\frac{1}{n}\\sum_{i=1}^{n}\\sum_{j=1}^{m} {y_j}^{[i]}⋅\\log(\\hat{{y_j}}^{[i]})$$\n",
        "\n",
        "where\n",
        "\n",
        "* $\\underset{n\\times m}{\\hat{y}}$: a matrix of predicted class distributions\n",
        "* $\\underset{n\\times m}{y}$: a one-hot encoded matrix of true classes\n",
        "\n",
        "The `tf.nn.sparse_softmax_cross_entropy_with_logits` function can be used to compute the cross-entropy loss. This function does not require the model's last layer to apply the softmax activation function, nor does it require the class labels to be one-hot encoded."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rskOYA7FVCwg"
      },
      "outputs": [],
      "source": [
        "def cross_entropy_loss(y_pred, y):\n",
        "  # Compute cross entropy loss with a sparse operation\n",
        "  sparse_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=y_pred)\n",
        "  return tf.reduce_mean(sparse_ce)"
      ]
    },
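    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sketch below works through the cross-entropy formula on a made-up batch using NumPy. It illustrates why one-hot encoding is optional: multiplying by a one-hot row simply selects the log-probability of the true class, which the sparse form obtains by indexing. The logits, labels, and variable names are illustrative assumptions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical logits for n=2 examples and m=3 classes\n",
        "toy_logits = np.array([[2.0, 1.0, 0.1],\n",
        "                       [0.5, 2.5, 0.3]])\n",
        "# Integer (sparse) class labels\n",
        "toy_labels = np.array([0, 1])\n",
        "\n",
        "# Softmax over the class axis\n",
        "exp_logits = np.exp(toy_logits - toy_logits.max(axis=1, keepdims=True))\n",
        "toy_probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)\n",
        "\n",
        "# Formula form: double sum over examples and one-hot encoded classes\n",
        "one_hot = np.eye(3)[toy_labels]\n",
        "loss_one_hot = -np.mean(np.sum(one_hot * np.log(toy_probs), axis=1))\n",
        "\n",
        "# Sparse form: index the log-probability of the true class directly\n",
        "loss_sparse = -np.mean(np.log(toy_probs[np.arange(2), toy_labels]))\n",
        "\n",
        "print(loss_one_hot, loss_sparse)"
      ]
    },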
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BvWxED1km8jh"
      },
      "source": [
        "Write a basic accuracy function that calculates the proportion of correct classifications during training. In order to generate class predictions from softmax outputs, return the index that corresponds to the largest class probability. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jPJMWx2UgiBm"
      },
      "outputs": [],
      "source": [
        "def accuracy(y_pred, y):\n",
        "  # Compute accuracy after extracting class predictions\n",
        "  class_preds = tf.argmax(tf.nn.softmax(y_pred), axis=1)\n",
        "  is_equal = tf.equal(y, class_preds)\n",
        "  return tf.reduce_mean(tf.cast(is_equal, tf.float32))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JSiNRhTOnKZr"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "An adaptive optimizer such as Adam can converge significantly faster than standard gradient descent. The Adam optimizer is implemented below. Visit the [Optimizers](https://www.tensorflow.org/guide/core/optimizers_core) guide to learn more about designing custom optimizers with TensorFlow Core."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iGIBDk3cAv6a"
      },
      "outputs": [],
      "source": [
        "class Adam:\n",
        "\n",
        "  def __init__(self, learning_rate=1e-3, beta_1=0.9, beta_2=0.999, ep=1e-7):\n",
        "    # Initialize optimizer parameters and variable slots\n",
        "    self.beta_1 = beta_1\n",
        "    self.beta_2 = beta_2\n",
        "    self.learning_rate = learning_rate\n",
        "    self.ep = ep\n",
        "    self.t = 1.\n",
        "    self.v_dvar, self.s_dvar = [], []\n",
        "    self.built = False\n",
        "\n",
        "  def apply_gradients(self, grads, variables):\n",
        "    # Initialize the moment slots on the first call\n",
        "    if not self.built:\n",
        "      for var in variables:\n",
        "        v = tf.Variable(tf.zeros(shape=var.shape))\n",
        "        s = tf.Variable(tf.zeros(shape=var.shape))\n",
        "        self.v_dvar.append(v)\n",
        "        self.s_dvar.append(s)\n",
        "      self.built = True\n",
        "    # Update the model variables given their gradients\n",
        "    for i, (d_var, var) in enumerate(zip(grads, variables)):\n",
        "      self.v_dvar[i].assign(self.beta_1*self.v_dvar[i] + (1-self.beta_1)*d_var)\n",
        "      self.s_dvar[i].assign(self.beta_2*self.s_dvar[i] + (1-self.beta_2)*tf.square(d_var))\n",
        "      v_dvar_bc = self.v_dvar[i]/(1-(self.beta_1**self.t))\n",
        "      s_dvar_bc = self.s_dvar[i]/(1-(self.beta_2**self.t))\n",
        "      var.assign_sub(self.learning_rate*(v_dvar_bc/(tf.sqrt(s_dvar_bc) + self.ep)))\n",
        "    self.t += 1."
      ]
    },
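    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the update that the `Adam` class above applies to each variable at step $t$ can be written as follows. The symbols $\\theta$ (the variable), $g$ (its gradient), and $\\alpha$ (the learning rate) are notation introduced here; $v$ and $s$ correspond to `v_dvar` and `s_dvar`, and $\\epsilon$ corresponds to `ep`. This is a transcription of the code rather than an additional derivation:\n",
        "\n",
        "$$v_t = \\beta_1 v_{t-1} + (1 - \\beta_1) g_t$$\n",
        "\n",
        "$$s_t = \\beta_2 s_{t-1} + (1 - \\beta_2) g_t^2$$\n",
        "\n",
        "$$\\hat{v}_t = \\frac{v_t}{1 - \\beta_1^t}, \\qquad \\hat{s}_t = \\frac{s_t}{1 - \\beta_2^t}$$\n",
        "\n",
        "$$\\theta_t = \\theta_{t-1} - \\alpha \\frac{\\hat{v}_t}{\\sqrt{\\hat{s}_t} + \\epsilon}$$\n",
        "\n",
        "The divisions by $1 - \\beta_1^t$ and $1 - \\beta_2^t$ compensate for both running averages starting at zero, which would otherwise bias the early estimates toward zero."
      ]
    },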
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "osEK3rqpYfKd"
      },
      "source": [
        "Now, write a custom training loop that updates the MLP parameters with mini-batch gradient descent. Using mini-batches for training provides both memory efficiency and faster convergence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CJLeY2ao1aw6"
      },
      "outputs": [],
      "source": [
        "def train_step(x_batch, y_batch, loss, acc, model, optimizer):\n",
        "  # Update the model state given a batch of data\n",
        "  with tf.GradientTape() as tape:\n",
        "    y_pred = model(x_batch)\n",
        "    batch_loss = loss(y_pred, y_batch)\n",
        "  batch_acc = acc(y_pred, y_batch)\n",
        "  grads = tape.gradient(batch_loss, model.variables)\n",
        "  optimizer.apply_gradients(grads, model.variables)\n",
        "  return batch_loss, batch_acc\n",
        "\n",
        "def val_step(x_batch, y_batch, loss, acc, model):\n",
        "  # Evaluate the model on a given batch of validation data\n",
        "  y_pred = model(x_batch)\n",
        "  batch_loss = loss(y_pred, y_batch)\n",
        "  batch_acc = acc(y_pred, y_batch)\n",
        "  return batch_loss, batch_acc"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oC85kuZgmh3q"
      },
      "outputs": [],
      "source": [
        "def train_model(mlp, train_data, val_data, loss, acc, optimizer, epochs):\n",
        "  # Initialize data structures\n",
        "  train_losses, train_accs = [], []\n",
        "  val_losses, val_accs = [], []\n",
        "\n",
        "  # Format training loop and begin training\n",
        "  for epoch in range(epochs):\n",
        "    batch_losses_train, batch_accs_train = [], []\n",
        "    batch_losses_val, batch_accs_val = [], []\n",
        "\n",
        "    # Iterate over the training data\n",
        "    for x_batch, y_batch in train_data:\n",
        "      # Compute gradients and update the model's parameters\n",
        "      batch_loss, batch_acc = train_step(x_batch, y_batch, loss, acc, mlp, optimizer)\n",
        "      # Keep track of batch-level training performance\n",
        "      batch_losses_train.append(batch_loss)\n",
        "      batch_accs_train.append(batch_acc)\n",
        "\n",
        "    # Iterate over the validation data\n",
        "    for x_batch, y_batch in val_data:\n",
        "      batch_loss, batch_acc = val_step(x_batch, y_batch, loss, acc, mlp)\n",
        "      batch_losses_val.append(batch_loss)\n",
        "      batch_accs_val.append(batch_acc)\n",
        "\n",
        "    # Keep track of epoch-level model performance\n",
        "    train_loss, train_acc = tf.reduce_mean(batch_losses_train), tf.reduce_mean(batch_accs_train)\n",
        "    val_loss, val_acc = tf.reduce_mean(batch_losses_val), tf.reduce_mean(batch_accs_val)\n",
        "    train_losses.append(train_loss)\n",
        "    train_accs.append(train_acc)\n",
        "    val_losses.append(val_loss)\n",
        "    val_accs.append(val_acc)\n",
        "    print(f\"Epoch: {epoch}\")\n",
        "    print(f\"Training loss: {train_loss:.3f}, Training accuracy: {train_acc:.3f}\")\n",
        "    print(f\"Validation loss: {val_loss:.3f}, Validation accuracy: {val_acc:.3f}\")\n",
        "  return train_losses, train_accs, val_losses, val_accs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FvbfXlN5lwwB"
      },
      "source": [
        "Train the MLP model for 10 epochs with a batch size of 128. Hardware accelerators like GPUs or TPUs can also help reduce training time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zPlT8QfxptYl"
      },
      "outputs": [],
      "source": [
        "train_losses, train_accs, val_losses, val_accs = train_model(mlp_model, train_data, val_data, \n",
        "                                                             loss=cross_entropy_loss, acc=accuracy,\n",
        "                                                             optimizer=Adam(), epochs=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j_RVmt43G12R"
      },
      "source": [
        "### Performance evaluation\n",
        "\n",
        "Start by writing a plotting function to visualize the model's loss and accuracy during training. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VXTCYVtNDjAM"
      },
      "outputs": [],
      "source": [
        "def plot_metrics(train_metric, val_metric, metric_type):\n",
        "  # Visualize metrics vs training Epochs\n",
        "  plt.figure()\n",
        "  plt.plot(range(len(train_metric)), train_metric, label = f\"Training {metric_type}\")\n",
        "  plt.plot(range(len(val_metric)), val_metric, label = f\"Validation {metric_type}\")\n",
        "  plt.xlabel(\"Epochs\")\n",
        "  plt.ylabel(metric_type)\n",
        "  plt.legend()\n",
        "  plt.title(f\"{metric_type} vs Training epochs\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DC-qIvZbHo0G"
      },
      "outputs": [],
      "source": [
        "plot_metrics(train_losses, val_losses, \"cross entropy loss\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P-w2xk2PIDve"
      },
      "outputs": [],
      "source": [
        "plot_metrics(train_accs, val_accs, \"accuracy\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tbrJJaFrD_XR"
      },
      "source": [
        "## Save and load the model\n",
        "\n",
        "Start by making an export module that takes in raw data and performs the following operations:\n",
        "- Data preprocessing \n",
        "- Probability prediction\n",
        "- Class prediction"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1sszfWuJJZoo"
      },
      "outputs": [],
      "source": [
        "class ExportModule(tf.Module):\n",
        "  def __init__(self, model, preprocess, class_pred):\n",
        "    # Initialize pre and postprocessing functions\n",
        "    self.model = model\n",
        "    self.preprocess = preprocess\n",
        "    self.class_pred = class_pred\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec(shape=[None, None, None, None], dtype=tf.uint8)]) \n",
        "  def __call__(self, x):\n",
        "    # Run the ExportModule for new data points\n",
        "    x = self.preprocess(x)\n",
        "    y = self.model(x)\n",
        "    y = self.class_pred(y)\n",
        "    return y "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p8x6gjTDVi5d"
      },
      "outputs": [],
      "source": [
        "def preprocess_test(x):\n",
        "  # The export module takes in unprocessed and unlabeled data\n",
        "  x = tf.reshape(x, shape=[-1, 784])\n",
        "  x = x/255\n",
        "  return x\n",
        "\n",
        "def class_pred_test(y):\n",
        "  # Generate class predictions from MLP output\n",
        "  return tf.argmax(tf.nn.softmax(y), axis=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vu9H5STrJzdo"
      },
      "source": [
        "This export module can now be saved with the `tf.saved_model.save` function. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fN9pPBQTKTe3"
      },
      "outputs": [],
      "source": [
        "mlp_model_export = ExportModule(model=mlp_model,\n",
        "                                preprocess=preprocess_test,\n",
        "                                class_pred=class_pred_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "idS7rQKbKwRS"
      },
      "outputs": [],
      "source": [
        "models = tempfile.mkdtemp()\n",
        "save_path = os.path.join(models, 'mlp_model_export')\n",
        "tf.saved_model.save(mlp_model_export, save_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_zZxO8iqBGZ-"
      },
      "source": [
        "Load the saved model with `tf.saved_model.load` and examine its performance on the unseen test data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W5cwBTUqxldW"
      },
      "outputs": [],
      "source": [
        "mlp_loaded = tf.saved_model.load(save_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bmv0u6j_b5OC"
      },
      "outputs": [],
      "source": [
        "def accuracy_score(y_pred, y):\n",
        "  # Generic accuracy function\n",
        "  is_equal = tf.equal(y_pred, y)\n",
        "  return tf.reduce_mean(tf.cast(is_equal, tf.float32))\n",
        "\n",
        "x_test, y_test = tfds.load(\"mnist\", split=['test'], batch_size=-1, as_supervised=True)[0]\n",
        "test_classes = mlp_loaded(x_test)\n",
        "test_acc = accuracy_score(test_classes, y_test)\n",
        "print(f\"Test Accuracy: {test_acc:.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j5t9vgv_ciQ_"
      },
      "source": [
        "The model does a great job of classifying handwritten digits in the training dataset and also generalizes well to unseen data. Now, examine the model's class-wise accuracy to ensure good performance for each digit. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UD8YiC1Vfeyp"
      },
      "outputs": [],
      "source": [
        "print(\"Accuracy breakdown by digit:\")\n",
        "print(\"---------------------------\")\n",
        "label_accs = {}\n",
        "for label in range(10):\n",
        "  label_ind = (y_test == label)\n",
        "  # Extract predictions for a specific true label\n",
        "  pred_label = test_classes[label_ind]\n",
        "  labels = y_test[label_ind]\n",
        "  # Compute class-wise accuracy, keyed by digit so that ties are not overwritten\n",
        "  label_accs[label] = accuracy_score(pred_label, labels).numpy()\n",
        "# Print the digits from lowest to highest accuracy\n",
        "for label, label_acc in sorted(label_accs.items(), key=lambda item: item[1]):\n",
        "  print(f\"Digit {label}: {label_acc:.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rcykuJFhdGb0"
      },
      "source": [
        "It looks like the model struggles with some digits a little more than others, which is quite common in many multiclass classification problems. As a final exercise, plot a confusion matrix of the model's predictions and their corresponding true labels to gather more class-level insights. scikit-learn and seaborn provide functions for generating and visualizing confusion matrices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JqCaqPwwh1tN"
      },
      "outputs": [],
      "source": [
        "import sklearn.metrics as sk_metrics\n",
        "\n",
        "def show_confusion_matrix(test_labels, test_classes):\n",
        "  # Compute confusion matrix and normalize\n",
        "  plt.figure(figsize=(10,10))\n",
        "  confusion = sk_metrics.confusion_matrix(test_labels.numpy(), \n",
        "                                          test_classes.numpy())\n",
        "  confusion_normalized = confusion / confusion.sum(axis=1, keepdims=True)\n",
        "  axis_labels = range(10)\n",
        "  ax = sns.heatmap(\n",
        "      confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,\n",
        "      cmap='Blues', annot=True, fmt='.4f', square=True)\n",
        "  plt.title(\"Confusion matrix\")\n",
        "  plt.ylabel(\"True label\")\n",
        "  plt.xlabel(\"Predicted label\")\n",
        "\n",
        "show_confusion_matrix(y_test, test_classes)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JT-WA7GVda6d"
      },
      "source": [
        "Class-level insights can help identify reasons for misclassifications and improve model performance in future training cycles."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VFLfEH4ManbW"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "This notebook introduced a few techniques to handle a multiclass classification problem with an [MLP](https://developers.google.com/machine-learning/crash-course/multi-class-neural-networks/softmax). Here are a few more tips that may help:\n",
        "\n",
        "- The [TensorFlow Core APIs](https://www.tensorflow.org/guide/core) can be used to build machine learning workflows with high levels of configurability.\n",
        "- Initialization schemes can help prevent model parameters from vanishing or exploding during training.\n",
        "- Overfitting is another common problem for neural networks, though it wasn't a problem for this tutorial. Visit the [Overfit and underfit](overfit_and_underfit.ipynb) tutorial for more help with this.\n",
        "\n",
        "For more examples of using the TensorFlow Core APIs, check out the [guide](https://www.tensorflow.org/guide/core). If you want to learn more about loading and preparing data, see the tutorials on [image data loading](https://www.tensorflow.org/tutorials/load_data/images) or [CSV data loading](https://www.tensorflow.org/tutorials/load_data/csv)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "FhGuhbZ6M5tl"
      ],
      "name": "mlp_core.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
